from typing import Any, Literal, Type

from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSDEScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    LCMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    TCDScheduler,
    UniPCMultistepScheduler,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin

# TODO: add dpmpp_3s/dpmpp_3s_k when fix released
# https://github.com/huggingface/diffusers/issues/9007

SCHEDULER_NAME_VALUES = Literal[
    "ddim",
    "ddpm",
    "deis",
    "deis_k",
    "lms",
    "lms_k",
    "pndm",
    "heun",
    "heun_k",
    "euler",
    "euler_k",
    "euler_a",
    "kdpm_2",
    "kdpm_2_k",
    "kdpm_2_a",
    "kdpm_2_a_k",
    "dpmpp_2s",
    "dpmpp_2s_k",
    "dpmpp_2m",
    "dpmpp_2m_k",
    "dpmpp_2m_sde",
    "dpmpp_2m_sde_k",
    "dpmpp_3m",
    "dpmpp_3m_k",
    "dpmpp_sde",
    "dpmpp_sde_k",
    "unipc",
    "unipc_k",
    "lcm",
    "tcd",
]

SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = {
    "ddim": (DDIMScheduler, {}),
    "ddpm": (DDPMScheduler, {}),
    "deis": (DEISMultistepScheduler, {"use_karras_sigmas": False}),
    "deis_k": (DEISMultistepScheduler, {"use_karras_sigmas": True}),
    "lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
    "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
    "pndm": (PNDMScheduler, {}),
    "heun": (HeunDiscreteScheduler, {"use_karras_sigmas": False}),
    "heun_k": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
    "euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
    "euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
    "euler_a": (EulerAncestralDiscreteScheduler, {}),
    "kdpm_2": (KDPM2DiscreteScheduler, {"use_karras_sigmas": False}),
    "kdpm_2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
    "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": False}),
    "kdpm_2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
    "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
    "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
    "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
    "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
    "dpmpp_2m_sde": (
        DPMSolverMultistepScheduler,
        {"use_karras_sigmas": False, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
    ),
    "dpmpp_2m_sde_k": (
        DPMSolverMultistepScheduler,
        {"use_karras_sigmas": True, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
    ),
    "dpmpp_3m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 3}),
    "dpmpp_3m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 3}),
    "dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
    "dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
    "unipc": (UniPCMultistepScheduler, {"use_karras_sigmas": False, "cpu_only": True}),
    "unipc_k": (UniPCMultistepScheduler, {"use_karras_sigmas": True, "cpu_only": True}),
    "lcm": (LCMScheduler, {}),
    "tcd": (TCDScheduler, {}),
}
